import torch
import torch.nn as nn
from .lora import LoRACompatibleConv
class BaseInteracteMoudle(nn.Module):
    """Two-stream interaction block with separate attention weights per stream.

    Every stream is first passed through its own self-attention layer; after
    that each stream queries the other one through its own cross-attention
    layer.

    Args:
        embed_dim: Feature size of every attention layer. It has to match
            dimension 1 of the tensors given to ``forward``.
        num_heads: Number of heads of every attention layer.
    """

    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()
        attn_cfg = {
            "embed_dim": embed_dim,
            "num_heads": num_heads,
            "batch_first": True,
        }

        # Self-attention, one layer per input stream.
        self.self_attention1 = nn.MultiheadAttention(**attn_cfg)
        self.self_attention2 = nn.MultiheadAttention(**attn_cfg)

        # Cross-attention, one layer per query stream.
        self.cross_attention1 = nn.MultiheadAttention(**attn_cfg)
        self.cross_attention2 = nn.MultiheadAttention(**attn_cfg)

        # Linear projection to 80 features. ``forward`` of this class never
        # calls it; it is still registered, so it is part of the parameters.
        self.output_layer = nn.Linear(embed_dim, 80)

    def forward(self, X1, X2):
        """Exchange information between the two streams.

        Args:
            X1: 4-D tensor whose shape is unpacked as
                ``(b, dim, seqlen, joint)``.
            X2: Tensor that is reshaped using ``b`` and ``dim`` of ``X1``.
                Its result is viewed with the full shape of ``X1``, so it is
                expected to have the same shape as ``X1``.

        Returns:
            A 4-tuple ``(out1, out2, None, None)``. ``out1`` is the first
            stream after attending to the second one, ``out2`` the second
            stream after attending to the first one; both have the shape of
            ``X1``.
        """
        n_batch, n_dim, n_seq, n_joint = X1.shape
        grid = (n_batch, n_dim, n_seq, n_joint)

        # Merge the two trailing axes into one token axis and move the
        # feature axis last, as required by ``batch_first`` attention.
        tokens_a = X1.reshape(n_batch, n_dim, -1).permute(0, 2, 1).contiguous()
        tokens_b = X2.reshape(n_batch, n_dim, -1).permute(0, 2, 1).contiguous()

        refined_a = self.self_attention1(tokens_a, tokens_a, tokens_a)[0]
        refined_b = self.self_attention2(tokens_b, tokens_b, tokens_b)[0]

        # Queries come from one stream, keys and values from the other.
        a_from_b = self.cross_attention1(refined_a, refined_b, refined_b)[0]
        b_from_a = self.cross_attention2(refined_b, refined_a, refined_a)[0]

        # Back to feature-first layout with the token axis split again.
        out_a = a_from_b.permute(0, 2, 1).view(grid)
        out_b = b_from_a.permute(0, 2, 1).view(grid)
        return out_a, out_b, None, None
import torch
import torch.nn as nn
    
class CrossInteracteMoudle(nn.Module):
    """Two-stream interaction block whose attention weights are shared.

    One self-attention layer is applied to both streams, and one
    cross-attention layer is used for both query directions.

    Args:
        embed_dim: Feature size of the attention layers. It has to match
            dimension 1 of the tensors given to ``forward``.
        num_heads: Number of heads of the attention layers.
    """

    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()
        attn_cfg = {
            "embed_dim": embed_dim,
            "num_heads": num_heads,
            "batch_first": True,
        }

        # Self-attention, used for both input streams.
        self.self_attention1 = nn.MultiheadAttention(**attn_cfg)

        # Cross-attention, used for both query directions.
        self.cross_attention1 = nn.MultiheadAttention(**attn_cfg)

    def forward(self, X1, X2):
        """Exchange information between the two streams.

        Args:
            X1: 4-D tensor whose shape is unpacked as
                ``(b, dim, seqlen, joint)``.
            X2: Tensor that is reshaped using ``b`` and ``dim`` of ``X1``.
                Its result is viewed with the full shape of ``X1``, so it is
                expected to have the same shape as ``X1``.

        Returns:
            A 4-tuple ``(out1, out2, None, None)``. ``out1`` is the first
            stream after attending to the second one, ``out2`` the second
            stream after attending to the first one; both have the shape of
            ``X1``.
        """
        n_batch, n_dim, n_seq, n_joint = X1.shape
        grid = (n_batch, n_dim, n_seq, n_joint)

        # Merge the two trailing axes into one token axis and move the
        # feature axis last, as required by ``batch_first`` attention.
        tokens_a = X1.reshape(n_batch, n_dim, -1).permute(0, 2, 1).contiguous()
        tokens_b = X2.reshape(n_batch, n_dim, -1).permute(0, 2, 1).contiguous()

        # The same self-attention weights process both streams.
        refined_a = self.self_attention1(tokens_a, tokens_a, tokens_a)[0]
        refined_b = self.self_attention1(tokens_b, tokens_b, tokens_b)[0]

        # The same cross-attention weights serve both directions; queries
        # come from one stream, keys and values from the other.
        a_from_b = self.cross_attention1(refined_a, refined_b, refined_b)[0]
        b_from_a = self.cross_attention1(refined_b, refined_a, refined_a)[0]

        # Back to feature-first layout with the token axis split again.
        out_a = a_from_b.permute(0, 2, 1).view(grid)
        out_b = b_from_a.permute(0, 2, 1).view(grid)
        return out_a, out_b, None, None
    


class MiddleCrossInteracteMoudle(nn.Module):
    """Shared-weight interaction block that also exposes the self-attention result.

    One self-attention layer is applied to both streams and one
    cross-attention layer is used for both query directions. Both the
    intermediate (self-attention) and the final (cross-attention) tensors
    are returned.

    Args:
        embed_dim: Feature size of the attention layers. It has to match
            dimension 1 of the tensors given to ``forward``.
        num_heads: Number of heads of the attention layers.
    """

    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()
        attn_cfg = {
            "embed_dim": embed_dim,
            "num_heads": num_heads,
            "batch_first": True,
        }

        # Self-attention, used for both input streams.
        self.self_attention1 = nn.MultiheadAttention(**attn_cfg)

        # Cross-attention, used for both query directions.
        self.cross_attention1 = nn.MultiheadAttention(**attn_cfg)

    def forward(self, X1, X2):
        """Run self-attention and cross-attention on the two streams.

        Args:
            X1: 4-D tensor whose shape is unpacked as
                ``(b, dim, seqlen, joint)``.
            X2: Tensor that is reshaped using ``b`` and ``dim`` of ``X1``.
                Its results are viewed with the full shape of ``X1``, so it
                is expected to have the same shape as ``X1``.

        Returns:
            A 4-tuple ``(self_out1, self_out2, cross_out1, cross_out2)``,
            every entry having the shape of ``X1``. The first two entries
            are the self-attention outputs of the two streams, the last two
            are the outputs after each stream attended to the other one.
        """
        n_batch, n_dim, n_seq, n_joint = X1.shape
        grid = (n_batch, n_dim, n_seq, n_joint)

        # Merge the two trailing axes into one token axis and move the
        # feature axis last, as required by ``batch_first`` attention.
        tokens_a = X1.reshape(n_batch, n_dim, -1).permute(0, 2, 1).contiguous()
        tokens_b = X2.reshape(n_batch, n_dim, -1).permute(0, 2, 1).contiguous()

        # The same self-attention weights process both streams.
        refined_a = self.self_attention1(tokens_a, tokens_a, tokens_a)[0]
        refined_b = self.self_attention1(tokens_b, tokens_b, tokens_b)[0]
        self_out_a = refined_a.permute(0, 2, 1).view(grid)
        self_out_b = refined_b.permute(0, 2, 1).view(grid)

        # The same cross-attention weights serve both directions; queries
        # come from one stream, keys and values from the other.
        a_from_b = self.cross_attention1(refined_a, refined_b, refined_b)[0]
        b_from_a = self.cross_attention1(refined_b, refined_a, refined_a)[0]
        cross_out_a = a_from_b.permute(0, 2, 1).view(grid)
        cross_out_b = b_from_a.permute(0, 2, 1).view(grid)

        return self_out_a, self_out_b, cross_out_a, cross_out_b

class LORACrossInteracteMoudle(nn.Module):
    """Interaction block that mixes the two streams with a convolution.

    One self-attention layer is applied to both streams. The two results are
    then concatenated along dimension 1 and passed through a single
    ``LoRACompatibleConv`` layer, which takes the place of cross-attention.

    Args:
        embed_dim: Feature size of the self-attention layer. It has to match
            dimension 1 of the tensors given to ``forward``; the mixing
            layer is created with ``2 * embed_dim`` input and output
            channels.
        num_heads: Number of heads of the self-attention layer.
    """

    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()

        # Self-attention, used for both input streams.
        self.self_attention1 = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )

        # Mixing layer applied to the concatenation of both streams. It
        # keeps the attribute name ``cross_attention1`` although it is a
        # convolution-style layer, not an attention layer.
        # NOTE(review): assumes LoRACompatibleConv follows nn.Conv2d
        # semantics for these arguments - confirm against ``.lora``.
        fused_channels = 2 * embed_dim
        self.cross_attention1 = LoRACompatibleConv(
            in_channels=fused_channels,
            out_channels=fused_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )

    def forward(self, X1, X2):
        """Run self-attention on both streams, then mix them jointly.

        Args:
            X1: 4-D tensor whose shape is unpacked as
                ``(b, dim, seqlen, joint)``.
            X2: Tensor that is reshaped using ``b`` and ``dim`` of ``X1``.
                Its result is viewed with the full shape of ``X1``, so it is
                expected to have the same shape as ``X1``.

        Returns:
            A 4-tuple ``(self_out1, self_out2, cross_out1, cross_out2)``.
            The first two entries are the self-attention outputs in the
            shape of ``X1``. The last two are the slices ``[:dim]`` and
            ``[dim:]`` along dimension 1 of the mixing layer's output.
        """
        n_batch, n_dim, n_seq, n_joint = X1.shape
        grid = (n_batch, n_dim, n_seq, n_joint)

        # Merge the two trailing axes into one token axis and move the
        # feature axis last, as required by ``batch_first`` attention.
        tokens_a = X1.reshape(n_batch, n_dim, -1).permute(0, 2, 1).contiguous()
        tokens_b = X2.reshape(n_batch, n_dim, -1).permute(0, 2, 1).contiguous()

        # The same self-attention weights process both streams.
        refined_a = self.self_attention1(tokens_a, tokens_a, tokens_a)[0]
        refined_b = self.self_attention1(tokens_b, tokens_b, tokens_b)[0]
        self_out_a = refined_a.permute(0, 2, 1).view(grid)
        self_out_b = refined_b.permute(0, 2, 1).view(grid)

        # Stack both streams along dimension 1, mix them in one pass and
        # split the result back into one part per stream.
        stacked = torch.cat([self_out_a, self_out_b], dim=1)
        mixed = self.cross_attention1(stacked)
        cross_out_a = mixed[:, :n_dim]
        cross_out_b = mixed[:, n_dim:]

        return self_out_a, self_out_b, cross_out_a, cross_out_b


